{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"In progress.\n",
    "\n",
    "Parag K. Mital, Jan 2016.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import tensorflow as tf\n",
    "from libs.connections import conv2d, linear\n",
    "from collections import namedtuple\n",
    "from math import sqrt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def residual_network(x, n_outputs,\n",
    "                     activation=tf.nn.relu):\n",
    "    \"\"\"Builds a residual network.\n",
    "\n",
    "    The graph is: a 7x7 convolution to 64 channels, a 3x3 max pool with\n",
    "    stride 2, a 1x1 convolution up to the first block's width, four groups\n",
    "    of three bottleneck residual units (1x1 -> 3x3 -> 1x1 plus an identity\n",
    "    shortcut), a global average pool and a final softmax layer.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : Placeholder\n",
    "        Input to the network.  A 2D Tensor [batch, n_features] is reshaped\n",
    "        to [batch, sqrt(n_features), sqrt(n_features), 1]; any other input\n",
    "        is passed to the first convolution unchanged.\n",
    "    n_outputs : int\n",
    "        Number of outputs of final softmax\n",
    "    activation : Attribute, optional\n",
    "        Nonlinearity passed to the first convolution and to the three\n",
    "        convolutions inside every residual unit.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    net : Tensor\n",
    "        Output of the final `linear` layer with a softmax activation\n",
    "        (presumably [batch, n_outputs]; `linear` is defined in\n",
    "        libs.connections).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a 2D Tensor is input, the Tensor must be square or else\n",
    "        the network can't be converted to a 4D Tensor.  Also raised when\n",
    "        the feature dimension of a 2D input is not statically known.\n",
    "    \"\"\"\n",
    "    # %%\n",
    "    LayerBlock = namedtuple(\n",
    "        'LayerBlock', ['num_repeats', 'num_filters', 'bottleneck_size'])\n",
    "    blocks = [LayerBlock(3, 128, 32),\n",
    "              LayerBlock(3, 256, 64),\n",
    "              LayerBlock(3, 512, 128),\n",
    "              LayerBlock(3, 1024, 256)]\n",
    "\n",
    "    # %%\n",
    "    # Flattened (2D) input is turned back into a square, 1-channel image.\n",
    "    input_shape = x.get_shape().as_list()\n",
    "    if len(input_shape) == 2:\n",
    "        n_features = input_shape[1]\n",
    "        if n_features is None:\n",
    "            # sqrt(None) would otherwise fail with an opaque TypeError.\n",
    "            raise ValueError(\n",
    "                'input_shape should be square, but the feature '\n",
    "                'dimension is unknown')\n",
    "        ndim = int(sqrt(n_features))\n",
    "        if ndim * ndim != n_features:\n",
    "            raise ValueError('input_shape should be square')\n",
    "        x = tf.reshape(x, [-1, ndim, ndim, 1])\n",
    "\n",
    "    # %%\n",
    "    # First convolution expands to 64 channels (no stride is passed, so any\n",
    "    # downsampling here comes from conv2d's defaults in libs.connections)\n",
    "    net = conv2d(x, 64, k_h=7, k_w=7,\n",
    "                 name='conv1',\n",
    "                 activation=activation)\n",
    "\n",
    "    # %%\n",
    "    # Max pool and downsampling\n",
    "    net = tf.nn.max_pool(\n",
    "        net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')\n",
    "\n",
    "    # %%\n",
    "    # Setup first chain of resnets\n",
    "    net = conv2d(net, blocks[0].num_filters, k_h=1, k_w=1,\n",
    "                 stride_h=1, stride_w=1, padding='VALID', name='conv2')\n",
    "\n",
    "    # %%\n",
    "    # Loop through all res blocks\n",
    "    n_blocks = len(blocks)\n",
    "    for block_i, block in enumerate(blocks):\n",
    "        for repeat_i in range(block.num_repeats):\n",
    "\n",
    "            name = 'block_%d/repeat_%d' % (block_i, repeat_i)\n",
    "            # 1x1 convolution down to the bottleneck width\n",
    "            conv = conv2d(net, block.bottleneck_size, k_h=1, k_w=1,\n",
    "                          padding='VALID', stride_h=1, stride_w=1,\n",
    "                          activation=activation,\n",
    "                          name=name + '/conv_in')\n",
    "\n",
    "            # 3x3 convolution at the bottleneck width\n",
    "            conv = conv2d(conv, block.bottleneck_size, k_h=3, k_w=3,\n",
    "                          padding='SAME', stride_h=1, stride_w=1,\n",
    "                          activation=activation,\n",
    "                          name=name + '/conv_bottleneck')\n",
    "\n",
    "            # 1x1 convolution back up to the block width\n",
    "            conv = conv2d(conv, block.num_filters, k_h=1, k_w=1,\n",
    "                          padding='VALID', stride_h=1, stride_w=1,\n",
    "                          activation=activation,\n",
    "                          name=name + '/conv_out')\n",
    "\n",
    "            # Identity shortcut\n",
    "            net = conv + net\n",
    "\n",
    "        # upscale to the next block size.  An explicit bounds check is used\n",
    "        # instead of catching IndexError, which would also have hidden an\n",
    "        # IndexError raised from inside conv2d itself.\n",
    "        if block_i + 1 != n_blocks:\n",
    "            next_block = blocks[block_i + 1]\n",
    "            net = conv2d(net, next_block.num_filters, k_h=1, k_w=1,\n",
    "                         padding='SAME', stride_h=1, stride_w=1, bias=False,\n",
    "                         name='block_%d/conv_upscale' % block_i)\n",
    "\n",
    "    # %%\n",
    "    # Average pool over the full spatial extent, then flatten\n",
    "    net_shape = net.get_shape().as_list()\n",
    "    net = tf.nn.avg_pool(net,\n",
    "                         ksize=[1, net_shape[1], net_shape[2], 1],\n",
    "                         strides=[1, 1, 1, 1], padding='VALID')\n",
    "    pooled_shape = net.get_shape().as_list()\n",
    "    net = tf.reshape(\n",
    "        net,\n",
    "        [-1, pooled_shape[1] * pooled_shape[2] * pooled_shape[3]])\n",
    "\n",
    "    net = linear(net, n_outputs, activation=tf.nn.softmax)\n",
    "\n",
    "    # %%\n",
    "    return net\n",
    "\n",
    "\n",
    "def test_mnist():\n",
    "    \"\"\"Test the resnet on MNIST.\n",
    "\n",
    "    Builds `residual_network` on flattened 784-value inputs with 10 one-hot\n",
    "    classes, trains it with Adam for 5 epochs in minibatches of 50, and\n",
    "    prints the mean training and validation accuracy after every epoch.\n",
    "    The dataset is read through `input_data.read_data_sets('MNIST_data/')`.\n",
    "    \"\"\"\n",
    "    import tensorflow.examples.tutorials.mnist.input_data as input_data\n",
    "\n",
    "    mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)\n",
    "    x = tf.placeholder(tf.float32, [None, 784])\n",
    "    y = tf.placeholder(tf.float32, [None, 10])\n",
    "    y_pred = residual_network(x, 10)\n",
    "\n",
    "    # %% Define loss/eval/training functions\n",
    "    # Clip the softmax output away from zero before taking the log: a\n",
    "    # probability of exactly 0 gives log(0) = -inf and 0 * -inf = NaN, which\n",
    "    # would then propagate into every gradient.\n",
    "    cross_entropy = -tf.reduce_sum(\n",
    "        y * tf.log(tf.clip_by_value(y_pred, 1e-10, 1.0)))\n",
    "    optimizer = tf.train.AdamOptimizer().minimize(cross_entropy)\n",
    "\n",
    "    # %% Monitor accuracy\n",
    "    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))\n",
    "\n",
    "    # %% We'll train in minibatches and report accuracy:\n",
    "    batch_size = 50\n",
    "    n_epochs = 5\n",
    "    n_train_batches = mnist.train.num_examples // batch_size\n",
    "    n_valid_batches = mnist.validation.num_examples // batch_size\n",
    "\n",
    "    # %% We now create a new session to actually perform the initialization of\n",
    "    # the variables.  The `with` block closes the session (releasing its\n",
    "    # resources) when training finishes or an error is raised.\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.initialize_all_variables())\n",
    "\n",
    "        for epoch_i in range(n_epochs):\n",
    "            # Training: run one optimizer step per batch and accumulate the\n",
    "            # accuracy measured on that same batch.\n",
    "            train_accuracy = 0\n",
    "            for _ in range(n_train_batches):\n",
    "                batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n",
    "                train_accuracy += sess.run(\n",
    "                    [optimizer, accuracy],\n",
    "                    feed_dict={x: batch_xs, y: batch_ys})[1]\n",
    "            train_accuracy /= n_train_batches\n",
    "\n",
    "            # Validation: accuracy only, no optimizer step.\n",
    "            valid_accuracy = 0\n",
    "            for _ in range(n_valid_batches):\n",
    "                batch_xs, batch_ys = mnist.validation.next_batch(batch_size)\n",
    "                valid_accuracy += sess.run(\n",
    "                    accuracy,\n",
    "                    feed_dict={x: batch_xs, y: batch_ys})\n",
    "            valid_accuracy /= n_valid_batches\n",
    "            print('epoch:', epoch_i, ', train:',\n",
    "                  train_accuracy, ', valid:', valid_accuracy)\n",
    "\n",
    "\n",
    "# Entry point: run the MNIST training test when this code is executed as\n",
    "# the main module (skipped when it is merely imported).\n",
    "if __name__ == '__main__':\n",
    "    test_mnist()"
   ]
  }
 ],
 "metadata": {
  "language": "python"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
